from functools import reduce
from torch import nn


class SKConv(nn.Module):
    """Selective-kernel convolution.

    Runs ``M`` parallel grouped 3x3 convolution branches on the same input
    (branch ``i`` uses ``dilation = padding = 1 + i``), sums them, squeezes the
    sum to a per-channel descriptor with global average pooling, and uses two
    1x1 convolutions to produce one softmax-normalised weight per branch and
    channel. The output is the weighted sum of the branch outputs.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by every branch and by the module.
        stride: Stride of the 3x3 branch convolutions.
        M: Number of parallel branches.
        r: Reduction ratio for the bottleneck width ``d``.
        L: Lower bound for the bottleneck width ``d``.
        G: Number of groups of the branch convolutions. ``in_channels`` and
            ``out_channels`` must both be divisible by it. Defaults to 32.

    NOTE(review): ``fc1`` applies BatchNorm2d to a 1x1 map, so training with a
    batch of one sample presumably fails inside PyTorch -- confirm before
    relying on batch size 1 in train mode.
    """

    def __init__(self, in_channels, out_channels, stride=1, M=2, r=16, L=32, G=32):
        super(SKConv, self).__init__()
        # Bottleneck width of the attention branch: in_channels / r, but never below L.
        d = max(in_channels // r, L)
        self.M = M
        self.out_channels = out_channels
        # padding == dilation for a 3x3 kernel, so all branches yield the same
        # spatial size and can be summed element-wise.
        self.conv = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, stride,
                          padding=1 + i, dilation=1 + i, groups=G, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )
            for i in range(M)
        ])
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Sequential(nn.Conv2d(out_channels, d, 1, bias=False),
                                 nn.BatchNorm2d(d),
                                 nn.ReLU(inplace=True))
        # One logit per (branch, channel) pair.
        self.fc2 = nn.Conv2d(d, out_channels * M, 1, 1, bias=False)
        # Applied on the branch axis after reshaping to (batch, M, C, 1).
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return the attention-weighted sum of all branch outputs for ``x``.

        ``x`` is fed to ``nn.Conv2d`` branches, so it is a 4-D batch with
        ``in_channels`` channels; the result has ``out_channels`` channels.
        """
        batch_size = x.size(0)
        branches = [conv(x) for conv in self.conv]

        # Fuse: element-wise sum of the branches, then squeeze to 1x1.
        fea_U = reduce(lambda acc, fea: acc + fea, branches)
        fea_z = self.fc1(self.global_pool(fea_U))

        # Select: logits -> (batch, M, C, 1), normalised across the M branches.
        a_b = self.fc2(fea_z).reshape(batch_size, self.M, self.out_channels, 1)
        a_b = self.softmax(a_b)
        weights = [w.reshape(batch_size, self.out_channels, 1, 1)
                   for w in a_b.chunk(self.M, dim=1)]

        weighted = [fea * w for fea, w in zip(branches, weights)]
        return reduce(lambda acc, fea: acc + fea, weighted)


class SKBlock(nn.Module):
    """Residual bottleneck block built around an ``SKConv``.

    Main path: 1x1 conv + BN + ReLU, then ``SKConv`` (which carries the
    stride), then 1x1 conv + BN widening to ``out_channel * expansion``.
    The shortcut output is added to the main path before the final ReLU.

    Args:
        in_channel: Channels of the input feature map.
        out_channel: Width of the inner layers; the block outputs
            ``out_channel * expansion`` channels.
        stride: Stride passed to ``SKConv`` and to the projection shortcut.
    """

    # Ratio between the block's output channels and ``out_channel``.
    expansion = 2

    def __init__(self, in_channel, out_channel, stride=1):
        super().__init__()
        expanded = out_channel * self.expansion

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True),
        )
        self.conv2 = SKConv(out_channel, out_channel, stride)
        # No ReLU here: the activation is applied after the residual addition.
        self.conv3 = nn.Sequential(
            nn.Conv2d(out_channel, expanded, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(expanded),
        )
        self.relu = nn.ReLU(inplace=True)

        # A projection is only needed when the main path changes the spatial
        # size or the channel count; otherwise an empty Sequential passes the
        # input through unchanged.
        if stride != 1 or in_channel != expanded:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channel, expanded, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(expanded),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Run the main path, add the shortcut of ``x`` in place, apply ReLU."""
        features = self.conv3(self.conv2(self.conv1(x)))
        features += self.shortcut(x)
        return self.relu(features)
